{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "9f0d0f32-23b4-41a6-b364-579da297c326"
      },
      "outputs": [],
      "source": [
        "# @title Copyright & License (click to expand)\n",
        "# Copyright 2025 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dd53d60c-97eb-4c72-91ea-f274a753ab34"
      },
      "source": [
        "# Video Data Curation - Video Quality Filtering\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/multimodal-data-curation/quality-filtering.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Fuse-cases%2Fmultimodal-data-curation%2Fquality-filtering.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
        "    </a>\n",
        "  </td>    \n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/use-cases/multimodal-data-curation/quality-filtering.ipynb\">\n",
        "      <img src=\"https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32\" alt=\"Vertex AI logo\"><br> Open in Workbench\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/multimodal-data-curation/quality-filtering.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://raw.githubusercontent.com/primer/octicons/refs/heads/main/icons/mark-github-24.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "</table>\n",
        "\n",
        "<div style=\"clear: both;\"></div>\n",
        "\n",
        "<b>Share to:</b>\n",
        "\n",
        "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/multimodal-data-curation/quality-filtering.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/multimodal-data-curation/quality-filtering.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/multimodal-data-curation/quality-filtering.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/multimodal-data-curation/quality-filtering.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/multimodal-data-curation/quality-filtering.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
        "</a>            "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MgVK7IeKpW27"
      },
      "source": [
        "| | |\n",
        "|-|-|\n",
        "|Author(s) | John Semerdjian |"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9ef820fb-1203-4cab-965f-17093a4ba25e"
      },
      "source": [
        "## Overview\n",
        "\n",
        "After we have split and transcoded our initial set of videos, the next step in a video curation pipeline is to filter videos based on their quality. Our goal is to discard as many \"low quality\" clips as possible in order to use our limited modeling compute as efficiently as possible. \"Quality\" in this context refers to a collection of subjective and practical criteria based on the developer's own requirements. The practical requirements here may translate into filters that correspond to video metadata like resolution or the absence of visible text, while subjective quality filters may target the aesthetics of the video itself. Since there is not a single set of filters or thresholds to use for _all_ video data curation pipelines, we will focus on the process of extracting and inferring several fields from our video collection. The approaches below have been aggregated from multiple research papers for training text-to-video foundation models.\n",
        "\n",
        "This notebook is organized as follows:\n",
        "* Metadata Filters\n",
        "* OCR & Watermark Detection\n",
        "* Aesthetic Scoring\n",
        "* Motion Scores "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0cbf01f0-5f6e-4bcd-903f-84ccaad5332c"
      },
      "source": [
        "### Install required packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MpDAgOsK6kZn"
      },
      "outputs": [],
      "source": [
        "# Install (or upgrade) the third-party packages this notebook imports.\n",
        "%pip install --upgrade --user --quiet google-genai google-cloud-storage pyav opencv-python numpy torch transformers pillow"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Moror1y0Qq2z"
      },
      "source": [
        "### Restart runtime (Colab only)\n",
        "\n",
        "To use the newly installed packages, you must restart the runtime on Google Colab."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4KLm_nKmQtC8"
      },
      "outputs": [],
      "source": [
        "# Automatically restart kernel after installs so that your environment can access the new packages\n",
        "import sys\n",
        "\n",
        "# Only Colab needs the restart; other environments skip this cell's body.\n",
        "if \"google.colab\" in sys.modules:\n",
        "    import IPython\n",
        "\n",
        "    # Shut the kernel down with restart=True so a fresh one starts.\n",
        "    app = IPython.Application.instance()\n",
        "    app.kernel.do_shutdown(True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b37d4259-7e39-417b-8879-24f7575732c8"
      },
      "source": [
        "## Before you begin\n",
        "\n",
        "### Set your project ID\n",
        "\n",
        "**If you don't know your project ID**, try the following:\n",
        "* Run `gcloud config list`.\n",
        "* Run `gcloud projects list`.\n",
        "* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "caaf0d7e-c6cb-4e56-af5c-553db5180e00"
      },
      "outputs": [],
      "source": [
        "# Google Cloud project used by the Cloud Storage and Gemini clients created later in this notebook.\n",
        "PROJECT_ID = \"[YOUR_PROJECT_ID]\"  # @param {type:\"string\"}\n",
        "# Set the project id\n",
        "! gcloud config set project {PROJECT_ID}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "054d794d-cd2e-4280-95ac-859b264ea2d6"
      },
      "source": [
        "#### Region\n",
        "\n",
        "You can also change the `REGION` variable used by Vertex AI. Learn more about [Vertex AI regions](https://cloud.google.com/vertex-ai/docs/general/locations)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "0121bf60-1acd-4272-afaf-aa54b4ded263"
      },
      "outputs": [],
      "source": [
        "# Location passed to the Gemini client (Vertex AI) when it is created.\n",
        "REGION = \"us-central1\"  # @param {type:\"string\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "czjH2JfKaGfH"
      },
      "source": [
        "#### Bucket\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "c_iZzYtraF3y"
      },
      "outputs": [],
      "source": [
        "# Cloud Storage bucket that is listed for videos, and the same bucket as a gs:// URI.\n",
        "BUCKET_NAME = \"[YOUR_BUCKET_NAME]\"  # @param {type:\"string\"}\n",
        "BUCKET_URI = f\"gs://{BUCKET_NAME}\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eac9e842-d225-4876-836f-afdb1937d800"
      },
      "source": [
        "### Authenticate your Google Cloud account\n",
        "\n",
        "Depending on your Jupyter environment, you may have to manually authenticate. Follow the relevant instructions below.\n",
        "\n",
        "**1. Vertex AI Workbench**\n",
        "* Do nothing as you are already authenticated.\n",
        "\n",
        "**2. Local JupyterLab instance, uncomment and run:**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "23082eec-b1bd-4594-b5b5-56fe2b74db6f"
      },
      "outputs": [],
      "source": [
        "# ! gcloud auth login"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3c20f923-3c46-4d6d-80d2-d7cb22b1a8da"
      },
      "source": [
        "**3. Authenticate your notebook environment**\n",
        "\n",
        "If you are running this notebook on Google Colab, run the cell below to authenticate your environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "60302a3f-fad9-452c-8998-a9c9822d2732"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "\n",
        "# google.colab is only importable on Colab; the guard lets Run All succeed elsewhere.\n",
        "if \"google.colab\" in sys.modules:\n",
        "    from google.colab import auth\n",
        "\n",
        "    auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ac33116d-b079-46cb-9614-86326c211e00"
      },
      "source": [
        "**4. Service account or other**\n",
        "* See how to grant Cloud Storage permissions to your service account at https://cloud.google.com/storage/docs/gsutil/commands/iam#ch-examples."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e6a924d0-a034-4e53-b240-03d356c7b7a6"
      },
      "source": [
        "### Import libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "5cbbd54d0e08"
      },
      "outputs": [],
      "source": [
        "import io\n",
        "import json\n",
        "import math\n",
        "\n",
        "from PIL import Image\n",
        "import av\n",
        "import cv2 as cv\n",
        "from google import genai\n",
        "from google.cloud import storage\n",
        "from google.genai import types\n",
        "import numpy as np\n",
        "import torch\n",
        "from torch import nn\n",
        "from transformers import CLIPModel, CLIPProcessor\n",
        "\n",
        "# Shared clients: Cloud Storage for listing/reading blobs, Gemini through Vertex AI for generate_content calls.\n",
        "storage_client = storage.Client(project=PROJECT_ID)\n",
        "gemini_client = genai.Client(vertexai=True, project=PROJECT_ID, location=REGION)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "727c349d4873"
      },
      "source": [
        "### Load video paths from Cloud Storage\n",
        "\n",
        "We assume you have video data already stored in a Cloud Storage bucket. We will read a few videos and demonstrate their outputs in this notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "ba417857ce4d"
      },
      "outputs": [],
      "source": [
        "def load_video_paths(\n",
        "    bucket: str,\n",
        "    num_videos: int,\n",
        ") -> list[str]:\n",
        "    \"\"\"Lists the names of the first blobs found in a Cloud Storage bucket.\n",
        "\n",
        "    No filtering by file type is applied: every listed blob is returned.\n",
        "\n",
        "    Args:\n",
        "        bucket: Name of the Cloud Storage bucket to list\n",
        "        num_videos: Maximum number of blob names to return\n",
        "\n",
        "    Returns:\n",
        "        list[str]: Up to `num_videos` blob names, in listing order\n",
        "\n",
        "    \"\"\"\n",
        "    if num_videos <= 0:\n",
        "        return []\n",
        "\n",
        "    # Let the API cap the listing instead of fetching a full page and discarding most of it.\n",
        "    blob_iterator = storage_client.list_blobs(bucket, max_results=num_videos)\n",
        "    return [blob.name for blob in blob_iterator]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "080d79c0121a"
      },
      "outputs": [],
      "source": [
        "# Take the first 3 blob names, then fetch the matching Blob objects that later cells read from.\n",
        "videos = load_video_paths(BUCKET_NAME, 3)\n",
        "bucket = storage_client.get_bucket(BUCKET_NAME)\n",
        "blobs = [bucket.get_blob(blob_name=blob_name) for blob_name in videos]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "af0db3e8bc74"
      },
      "source": [
        "## Metadata Filters\n",
        "\n",
        "To extract video metadata we will use [PyAV](https://github.com/PyAV-Org/PyAV), a Python wrapper over the popular video processing software FFmpeg. Identifying the video metadata fields below is important for adhering to downstream modeling requirements. For example, training over very dark or very bright videos may lead to poor modeling performance, and we should therefore exclude them during training. \n",
        "\n",
        "There are several metadata fields to consider but we will only cover the following:\n",
        "* Frames per second (FPS)\n",
        "* Duration (already covered in a previous notebook but we include it here as well)\n",
        "* Resolution\n",
        "* Brightness\n",
        "* Aspect Ratio"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "31ef66832903"
      },
      "outputs": [],
      "source": [
        "def get_last_timestamp(container: av.container.Container) -> dict:\n",
        "    \"\"\"Counts the decoded video frames and reads the stream duration.\n",
        "\n",
        "    Args:\n",
        "        container: PyAV container\n",
        "\n",
        "    Returns:\n",
        "        dict: `last_frame_index` (the number of frames decoded from the first\n",
        "            video stream) and `duration` (stream duration multiplied by its\n",
        "            time base, rounded to 2 decimals; 0.0 when no duration is reported)\n",
        "\n",
        "    \"\"\"\n",
        "    stream = container.streams.video[0]\n",
        "\n",
        "    # A missing or zero duration falls back to 0.0 rather than raising.\n",
        "    stream_duration = (\n",
        "        float(stream.duration * stream.time_base) if stream.duration else 0.0\n",
        "    )\n",
        "\n",
        "    # Every frame is decoded; the count is what gets reported.\n",
        "    decoded_frames = sum(1 for _ in container.decode(stream))\n",
        "\n",
        "    return {\n",
        "        \"last_frame_index\": decoded_frames,\n",
        "        \"duration\": np.round(stream_duration, 2).item(),\n",
        "    }\n",
        "\n",
        "\n",
        "def get_average_brightness(\n",
        "    container: av.container.Container,\n",
        "    sample_interval: int = 10,\n",
        ") -> float:\n",
        "    \"\"\"Estimate the mean relative luminance of a video from evenly spaced frames.\n",
        "\n",
        "    See: https://en.wikipedia.org/wiki/Relative_luminance\n",
        "\n",
        "    Args:\n",
        "        container: PyAV container\n",
        "        sample_interval: Divisor applied to the total frame count to obtain\n",
        "            the sampling step (so it is roughly the number of sampled frames)\n",
        "\n",
        "    Returns:\n",
        "        float: Mean luminance of the sampled frames rounded to 2 decimals,\n",
        "            or 0.0 when no frame was decoded\n",
        "\n",
        "    \"\"\"\n",
        "    stream = container.streams.video[0]\n",
        "\n",
        "    # First pass: count frames so the step can be derived from the video length.\n",
        "    num_frames = sum(1 for _ in container.decode(stream))\n",
        "\n",
        "    container.seek(0)\n",
        "    step = max(1, num_frames // sample_interval)\n",
        "\n",
        "    # Second pass: average the weighted RGB channels of every step-th frame.\n",
        "    per_frame_means = []\n",
        "    for position, frame in enumerate(container.decode(stream)):\n",
        "        if position % step != 0:\n",
        "            continue\n",
        "        rgb = frame.to_ndarray(format=\"rgb24\")\n",
        "        weighted = 0.2126 * rgb[..., 0] + 0.7152 * rgb[..., 1] + 0.0722 * rgb[..., 2]\n",
        "        per_frame_means.append(np.mean(weighted))\n",
        "\n",
        "    if not per_frame_means:\n",
        "        return 0.0\n",
        "    return np.round(np.mean(per_frame_means), 2).item()\n",
        "\n",
        "\n",
        "def get_video_metadata(video_file_object: io.BytesIO) -> dict:\n",
        "    \"\"\"Extracts video metadata from a video file object using PyAV.\n",
        "\n",
        "    The video is decoded several times (frame counting and brightness\n",
        "    sampling), so the cost grows with the video length.\n",
        "\n",
        "    Args:\n",
        "        video_file_object: A binary file-like object containing the video data\n",
        "\n",
        "    Returns:\n",
        "        dict: `height`, `width`, `avg_fps` (truncated to an int, 0 when the\n",
        "            stream reports no average rate), `aspect_ratio` (width:height in\n",
        "            lowest terms), `aspect_ratio_type` (a name from the lookup table,\n",
        "            or the ratio itself when unknown), `avg_brightness`, plus the keys\n",
        "            returned by `get_last_timestamp`\n",
        "\n",
        "    \"\"\"\n",
        "    # The ratio computed below is always in lowest terms, so the reduced form\n",
        "    # of each named ratio must be present for the lookup to ever match.\n",
        "    aspect_ratio_map = {\n",
        "        \"16:9\": \"hdtv\",\n",
        "        \"4:3\": \"standard television\",\n",
        "        \"21:9\": \"ultrawide\",\n",
        "        \"7:3\": \"ultrawide\",  # 21:9 in lowest terms\n",
        "        \"64:27\": \"ultrawide\",  # e.g. 2560x1080\n",
        "        \"43:18\": \"ultrawide\",  # e.g. 3440x1440\n",
        "        \"3:2\": \"common photography\",\n",
        "        \"1:1\": \"square\",\n",
        "        \"5:4\": \"large format photography\",\n",
        "        \"16:10\": \"computer display\",\n",
        "        \"8:5\": \"computer display\",  # 16:10 in lowest terms\n",
        "        \"239:100\": \"anamorphic\",\n",
        "        \"47:20\": \"anamorphic\",\n",
        "        \"12:5\": \"anamorphic\",\n",
        "        \"37:20\": \"common widescreen theatrical\",\n",
        "        \"14:9\": \"cropped standard television\",\n",
        "    }\n",
        "\n",
        "    video_file_object.seek(0)\n",
        "    # The context manager closes the container even if decoding raises.\n",
        "    with av.open(video_file_object) as container:\n",
        "        video_stream = container.streams.video[0]\n",
        "\n",
        "        height = video_stream.height\n",
        "        width = video_stream.width\n",
        "        # Some streams report no average rate; fall back to 0 instead of crashing.\n",
        "        average_rate = video_stream.average_rate\n",
        "        avg_fps = float(average_rate) if average_rate else 0.0\n",
        "\n",
        "        common_divisor = math.gcd(width, height)\n",
        "        aspect_ratio = f\"{width // common_divisor}:{height // common_divisor}\"\n",
        "        aspect_ratio_str = aspect_ratio_map.get(aspect_ratio, aspect_ratio)\n",
        "\n",
        "        duration_info = get_last_timestamp(container)\n",
        "\n",
        "        # Rewind: counting frames above consumed the whole stream.\n",
        "        container.seek(0)\n",
        "        avg_brightness = get_average_brightness(container)\n",
        "\n",
        "    return {\n",
        "        \"height\": height,\n",
        "        \"width\": width,\n",
        "        \"avg_fps\": int(avg_fps),\n",
        "        \"aspect_ratio\": aspect_ratio,\n",
        "        \"aspect_ratio_type\": aspect_ratio_str,\n",
        "        \"avg_brightness\": avg_brightness,\n",
        "        **duration_info,\n",
        "    }"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2977f2ee1ffd"
      },
      "outputs": [],
      "source": [
        "# Read each video fully into memory, then print the metadata extracted from it.\n",
        "for b in blobs:\n",
        "    with b.open(\"rb\") as f:\n",
        "        video_bytes = f.read()\n",
        "    video_file_object = io.BytesIO(video_bytes)\n",
        "    video_metadata = get_video_metadata(video_file_object)\n",
        "    print(f\"{b.name}: {video_metadata}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d3534ba6501e"
      },
      "source": [
        "## OCR and Watermark Detection\n",
        "\n",
        "Videos with text may or may not be desirable in our curated dataset, depending on the capabilities of the model we want to train. On the other hand, video watermarks may be more consistently undesirable. Either way, we will walk through examples of doing both using a single call to Gemini 2.0 Flash. We will also incorporate structured outputs to separate the text from the watermark content as JSON objects."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "8c265d131385"
      },
      "outputs": [],
      "source": [
        "def get_text_and_watermarks(video_uri: str) -> dict:\n",
        "    \"\"\"Extracts overlaid text, watermark descriptions, and usage metadata from a video.\n",
        "\n",
        "    The request is sent with MIME type `video/mp4` and a response schema that\n",
        "    is an array of objects; only the first object is used.\n",
        "\n",
        "    Args:\n",
        "        video_uri: URI of the video passed to Gemini (the notebook passes a\n",
        "            `gs://` Cloud Storage URI)\n",
        "\n",
        "    Returns:\n",
        "        dict: `text` and `watermarks` (None when the model returned nothing for\n",
        "            them), merged with the response's usage metadata fields\n",
        "\n",
        "    \"\"\"\n",
        "    system_instruction = \"\"\"Your task is to carefully watch the provided video and do the following:\n",
        "\n",
        "* Extract all visible text that appears across all frames in the video (e.g. text overlays from a weather forecast, credits at the start or end of a video). For example, if there is text on a t-shirt, traffic signs, or something that is naturally part of the video content, do NOT include it in the response.\n",
        "* Identify and describe any visual watermarks in the video. These are typically logos or other images that are overlaid on the video during the production process.\n",
        "* If no text or watermarks are visible, return 'None' as the response.\n",
        "* Be concise, especially if there is no text or watermarks visible.\n",
        "* Return your response in JSON format with the following fields and nothing else:\n",
        "{\n",
        "    \"text\": \"The text found in the video (if any)\",\n",
        "    \"watermarks\": \"Watermark descriptions found in the video (if any)\"\n",
        "}\n",
        "\"\"\"\n",
        "\n",
        "    response = gemini_client.models.generate_content(\n",
        "        model=\"gemini-2.0-flash-001\",\n",
        "        contents=[\n",
        "            types.Part.from_text(text=\"Watch this video:\"),\n",
        "            types.Part.from_uri(file_uri=video_uri, mime_type=\"video/mp4\"),\n",
        "        ],\n",
        "        config=types.GenerateContentConfig(\n",
        "            system_instruction=system_instruction,\n",
        "            temperature=0.0,\n",
        "            response_mime_type=\"application/json\",\n",
        "            response_schema={\n",
        "                \"type\": \"ARRAY\",\n",
        "                \"items\": {\n",
        "                    \"type\": \"OBJECT\",\n",
        "                    \"properties\": {\n",
        "                        \"text\": {\"type\": \"STRING\"},\n",
        "                        \"watermarks\": {\"type\": \"STRING\"},\n",
        "                    },\n",
        "                },\n",
        "            },\n",
        "        ),\n",
        "    )\n",
        "\n",
        "    # The response may carry no text, an empty array, or objects missing a\n",
        "    # field (neither property is required), so default both keys.\n",
        "    parsed_items = json.loads(response.text) if response.text else []\n",
        "    first_item = parsed_items[0] if parsed_items else {}\n",
        "    usage = (\n",
        "        response.usage_metadata.to_json_dict() if response.usage_metadata else {}\n",
        "    )\n",
        "    return {\"text\": None, \"watermarks\": None, **first_item, **usage}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4b4429ed3df1"
      },
      "outputs": [],
      "source": [
        "# Run text/watermark extraction on each video through its URI in the bucket.\n",
        "for v in videos:\n",
        "    video_uri = f\"{BUCKET_URI}/{v}\"\n",
        "    print(f\"{v}:\")\n",
        "    text_and_watermarks = get_text_and_watermarks(video_uri)\n",
        "    text, watermarks = (text_and_watermarks[key] for key in (\"text\", \"watermarks\"))\n",
        "    print(f\"\\ttext: {text}\")\n",
        "    print(f\"\\twatermarks: {watermarks}\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d250de6ca196"
      },
      "source": [
        "## Aesthetic Scoring\n",
        "\n",
        "Aesthetic scoring involves using a model to evaluate the quality of sampled frames from our videos. Below we show how to use the [CLIP+MLP Aesthetic Score](https://github.com/christophschuhmann/improved-aesthetic-predictor) model that was trained on an annotated dataset of images and aesthetic scores, i.e. a numeric value on a scale of 1 to 10, where the higher the score the better. For more details on the underlying training data and classifier see the [LAION-Aesthetics blog](https://laion.ai/blog/laion-aesthetics/). We will sample frames from each video, run inference with the trained CLIP+MLP model above, and average the scores. We recommend figuring out the optimal score threshold to filter by manually inspecting a sample of low scoring videos.\n",
        "\n",
        "Alternatively, we can prompt a native multi-modal model like Gemini 2.x to make a similar classification, and even fine-tune Gemini 2.x using the labelled data from the LAION dataset. Either way, the threshold to use (or classification system) will involve manually inspecting a sample of low scoring predictions.\n",
        "\n",
        "Note: Be sure to download the [`sac+logos+ava1-l14-linearMSE.pth`](https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/refs/heads/main/sac+logos+ava1-l14-linearMSE.pth) model weights before running the code below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "fbe0d62a2bd0"
      },
      "outputs": [],
      "source": [
        "class MLP(nn.Module):\n",
        "    \"\"\"Stack of linear layers mapping an embedding to a single value.\n",
        "\n",
        "    Layer widths are input_size -> 1024 -> 128 -> 64 -> 16 -> 1, with dropout\n",
        "    after the first three linear layers and no activation functions.\n",
        "    `xcol` and `ycol` are stored as attributes but not used by `forward`.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, input_size: int, xcol: str = \"emb\", ycol: str = \"avg_rating\"):\n",
        "        super().__init__()\n",
        "        self.input_size = input_size\n",
        "        self.xcol = xcol\n",
        "        self.ycol = ycol\n",
        "\n",
        "        # Position in this list determines the state_dict keys (layers.0, layers.2, ...).\n",
        "        stack = [\n",
        "            nn.Linear(input_size, 1024),\n",
        "            nn.Dropout(0.2),\n",
        "            nn.Linear(1024, 128),\n",
        "            nn.Dropout(0.2),\n",
        "            nn.Linear(128, 64),\n",
        "            nn.Dropout(0.1),\n",
        "            nn.Linear(64, 16),\n",
        "            nn.Linear(16, 1),\n",
        "        ]\n",
        "        self.layers = nn.Sequential(*stack)\n",
        "\n",
        "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"Apply the layer stack to `x`.\"\"\"\n",
        "        output = self.layers(x)\n",
        "        return output\n",
        "\n",
        "\n",
        "def normalized(a, axis: int = -1, order: int = 2) -> np.ndarray:\n",
        "    \"\"\"Divide `a` by its norm along `axis`; vectors with zero norm are left as they are.\"\"\"\n",
        "    norms = np.atleast_1d(np.linalg.norm(a, ord=order, axis=axis))\n",
        "    # Substitute 1 for zero norms so the division below never divides by zero.\n",
        "    safe_norms = np.where(norms == 0, 1, norms)\n",
        "    return a / np.expand_dims(safe_norms, axis)\n",
        "\n",
        "\n",
        "def sample_video_frames(\n",
        "    video_file_object: io.BytesIO,\n",
        "    num_frames: int,\n",
        ") -> list[Image.Image]:\n",
        "    \"\"\"Sample frames from video at regular intervals.\n",
        "\n",
        "    With `num_frames == 1` the middle frame is returned. Otherwise the target\n",
        "    indices are spread evenly over the first 80% of the video; when that\n",
        "    produces repeated indices (more frames requested than available), each\n",
        "    frame is returned once, so fewer than `num_frames` images may come back.\n",
        "\n",
        "    Args:\n",
        "        video_file_object: A binary file-like object containing the video data\n",
        "        num_frames: Number of frames to sample\n",
        "\n",
        "    Returns:\n",
        "        List of PIL Images (empty when the video has no decodable frames)\n",
        "\n",
        "    \"\"\"\n",
        "    video_file_object.seek(0)\n",
        "    # The context manager closes the container even if decoding raises.\n",
        "    with av.open(video_file_object) as container:\n",
        "        video_stream = container.streams.video[0]\n",
        "\n",
        "        # First pass: count frames so the target indices can be computed.\n",
        "        total_frames = sum(1 for _ in container.decode(video_stream))\n",
        "\n",
        "        if num_frames == 1:\n",
        "            # Get the middle frame\n",
        "            target_indices = {total_frames // 2}\n",
        "        else:\n",
        "            # Sample frames at regular intervals, use 80% of frames\n",
        "            last_frame_index = int(total_frames * 0.8)\n",
        "            target_indices = set(\n",
        "                np.linspace(0, last_frame_index, num_frames).astype(int).tolist()\n",
        "            )\n",
        "\n",
        "        if total_frames == 0 or not target_indices:\n",
        "            return []\n",
        "\n",
        "        # Second pass: collect the targets and stop as soon as all were found.\n",
        "        container.seek(0)\n",
        "        sampled_frames = []\n",
        "        for frame_idx, frame in enumerate(container.decode(video_stream)):\n",
        "            if frame_idx in target_indices:\n",
        "                frame_array = frame.to_ndarray(format=\"rgb24\")\n",
        "                sampled_frames.append(Image.fromarray(frame_array))\n",
        "                if len(sampled_frames) == len(target_indices):\n",
        "                    break\n",
        "\n",
        "        return sampled_frames\n",
        "\n",
        "\n",
        "def get_frames(blob: storage.blob.Blob) -> list[Image.Image]:\n",
        "    \"\"\"Read a video blob and sample evenly spaced frames for aesthetic scoring.\n",
        "\n",
        "    The sampling step is `total_frames // 10` (at least 1), so a video with 10\n",
        "    or more frames yields between 10 and 19 images, and a shorter video yields\n",
        "    every one of its frames.\n",
        "\n",
        "    Args:\n",
        "        blob: Cloud Storage blob object\n",
        "\n",
        "    Returns:\n",
        "        List of PIL Images\n",
        "\n",
        "    \"\"\"\n",
        "    with blob.open(\"rb\") as f:\n",
        "        video_file_object = io.BytesIO(f.read())\n",
        "\n",
        "    video_file_object.seek(0)\n",
        "    # The context manager closes the container even if decoding raises.\n",
        "    with av.open(video_file_object) as container:\n",
        "        video_stream = container.streams.video[0]\n",
        "\n",
        "        # First pass: count frames so the step can be derived from the video length.\n",
        "        total_frames = sum(1 for _ in container.decode(video_stream))\n",
        "\n",
        "        container.seek(0)\n",
        "\n",
        "        # Second pass: keep every sample_interval-th frame.\n",
        "        sample_interval = max(1, total_frames // 10)\n",
        "        sampled_frames = []\n",
        "        for frame_idx, frame in enumerate(container.decode(video_stream)):\n",
        "            if frame_idx % sample_interval == 0:\n",
        "                frame_array = frame.to_ndarray(format=\"rgb24\")\n",
        "                sampled_frames.append(Image.fromarray(frame_array))\n",
        "\n",
        "    return sampled_frames\n",
        "\n",
        "\n",
        "def load_model(model_path: str) -> tuple[CLIPModel, CLIPProcessor, MLP, str]:\n",
        "    \"\"\"Load the aesthetic MLP weights and the CLIP model/processor.\n",
        "\n",
        "    Args:\n",
        "        model_path: Path to the MLP state dict file\n",
        "\n",
        "    Returns:\n",
        "        tuple: (clip_model, processor, mlp_model, device), where device is\n",
        "            \"cuda\" when available and \"cpu\" otherwise\n",
        "\n",
        "    \"\"\"\n",
        "    mlp_model = MLP(768)\n",
        "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "    # The checkpoint is a downloaded file: weights_only=True restricts\n",
        "    # unpickling to tensors and plain containers instead of arbitrary objects.\n",
        "    state_dict = torch.load(\n",
        "        model_path,\n",
        "        map_location=torch.device(device),\n",
        "        weights_only=True,\n",
        "    )\n",
        "    mlp_model.load_state_dict(state_dict)\n",
        "    mlp_model.to(device)\n",
        "    # eval() disables the dropout layers for inference.\n",
        "    mlp_model.eval()\n",
        "\n",
        "    clip_model_name = \"openai/clip-vit-large-patch14\"\n",
        "    processor = CLIPProcessor.from_pretrained(clip_model_name)\n",
        "    clip_model = CLIPModel.from_pretrained(clip_model_name).to(device)\n",
        "    return clip_model, processor, mlp_model, device\n",
        "\n",
        "\n",
        "def get_aesthetic_score(\n",
        "    clip_model: CLIPModel,\n",
        "    processor: CLIPProcessor,\n",
        "    mlp_model: MLP,\n",
        "    device: str,\n",
        "    pil_images: list[Image.Image],\n",
        ") -> float:\n",
        "    \"\"\"Compute the mean aesthetic score over a set of images.\n",
        "\n",
        "    Each image is embedded with CLIP, the embeddings are passed through\n",
        "    `normalized`, and the MLP scores them. The scores are then averaged.\n",
        "\n",
        "    Args:\n",
        "        clip_model: CLIP model used to compute the image embeddings\n",
        "        processor: CLIP processor that prepares the images for `clip_model`\n",
        "        mlp_model: MLP that maps an embedding to a score\n",
        "        device: Device the models live on, e.g. \"cuda\" or \"cpu\"\n",
        "        pil_images: Images to score. The value is handed to the processor\n",
        "            unchanged.\n",
        "\n",
        "    Returns:\n",
        "        Mean of the MLP outputs over all images, as a Python float.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If `pil_images` is an empty list or tuple.\n",
        "\n",
        "    \"\"\"\n",
        "    if isinstance(pil_images, (list, tuple)) and len(pil_images) == 0:\n",
        "        raise ValueError(\"pil_images is empty; at least one image is required.\")\n",
        "\n",
        "    inputs = processor(\n",
        "        text=None,\n",
        "        images=pil_images,\n",
        "        return_tensors=\"pt\",\n",
        "        padding=True,\n",
        "    ).to(device)\n",
        "\n",
        "    # Inference only: keep both forward passes out of autograd so the result\n",
        "    # carries no gradient graph.\n",
        "    with torch.no_grad():\n",
        "        image_features = clip_model.get_image_features(\n",
        "            pixel_values=inputs[\"pixel_values\"],\n",
        "        )\n",
        "        im_emb_arr = normalized(image_features.cpu().detach().numpy())\n",
        "        input_tensor = torch.from_numpy(im_emb_arr).to(device).float()\n",
        "        scores = mlp_model(input_tensor)\n",
        "\n",
        "    # .item() converts the 0-d tensor into the float promised by the signature.\n",
        "    return scores.mean().item()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "32871eb5f38f"
      },
      "outputs": [],
      "source": [
        "# Load the models once and reuse them for every video.\n",
        "clip_model, processor, mlp_model, device = load_model(\n",
        "    \"./sac+logos+ava1-l14-linearMSE.pth\"\n",
        ")\n",
        "\n",
        "# Score the sampled frames of each video and report one value per blob.\n",
        "for video_blob in blobs:\n",
        "    frames = get_frames(video_blob)\n",
        "    score = get_aesthetic_score(clip_model, processor, mlp_model, device, frames)\n",
        "    print(f\"{video_blob.name}: {score}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8176f9f4a1e6"
      },
      "source": [
        "## Motion Scores\n",
        "\n",
        "In order to detect different types of motion, we can measure the difference between consecutive frames in a video. Since we only care about measuring movement in the context of slow or very fast motion, we will focus on coarser and more computationally efficient approaches that fall under *Sparse Optical Flow Estimation*. These techniques look at displacement in important patches of an image, in contrast to *Dense Optical Flow Estimation*, which examines all pixel values ([RAFT (Recurrent All-pairs Field Transforms)](https://arxiv.org/pdf/2003.12039) is one example of a Dense Optical Flow model). The key method in OpenCV is `cv.calcOpticalFlowPyrLK`, which uses the Lucas-Kanade method to calculate optical flow. See the [OpenCV documentation](https://docs.opencv.org/4.x/db/d7f/tutorial_js_lucas_kanade.html) for more details."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 83,
      "metadata": {
        "id": "ef7263b4525f"
      },
      "outputs": [],
      "source": [
        "def estimate_video_motion(video_file_object: io.BytesIO) -> float:\n",
        "    \"\"\"Estimates motion in a video with sparse Lucas-Kanade optical flow.\n",
        "\n",
        "    Corner features are detected in the first frame and tracked from frame to\n",
        "    frame. The motion of a frame is the mean displacement of the points that\n",
        "    were tracked successfully, and the result is the mean of those values.\n",
        "\n",
        "    Args:\n",
        "        video_file_object: A binary file-like object containing the video data\n",
        "\n",
        "    Returns:\n",
        "        A single float representing the estimated average motion per frame.\n",
        "        Returns 0.0 when no features are found in the first frame or when no\n",
        "        frame contributed a measurement.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If the video data cannot be read or is empty.\n",
        "\n",
        "    \"\"\"\n",
        "    feature_params = {\n",
        "        \"maxCorners\": 100,\n",
        "        \"qualityLevel\": 0.3,\n",
        "        \"minDistance\": 7,\n",
        "        \"blockSize\": 7,\n",
        "    }\n",
        "    lk_params = {\n",
        "        \"winSize\": (15, 15),\n",
        "        \"maxLevel\": 2,\n",
        "        \"criteria\": (cv.TERM_CRITERIA_EPS | cv.TERM_CRITERIA_COUNT, 10, 0.03),\n",
        "    }\n",
        "\n",
        "    motion_per_frame = []\n",
        "    container = None\n",
        "\n",
        "    try:\n",
        "        try:\n",
        "            video_file_object.seek(0)\n",
        "            container = av.open(video_file_object)\n",
        "            stream = container.streams.video[0]\n",
        "            frames = container.decode(stream)\n",
        "            # Get the first frame; an empty stream raises StopIteration here\n",
        "            # and is reported as a ValueError like any other read failure.\n",
        "            old_frame = next(frames)\n",
        "            old_frame_bgr = old_frame.to_ndarray(format=\"bgr24\")\n",
        "        except Exception as e:\n",
        "            raise ValueError(\n",
        "                f\"Failed to read video data from memory. Error: {e}\"\n",
        "            ) from e\n",
        "\n",
        "        old_gray = cv.cvtColor(old_frame_bgr, cv.COLOR_BGR2GRAY)\n",
        "        p0 = cv.goodFeaturesToTrack(old_gray, mask=None, **feature_params)\n",
        "\n",
        "        if p0 is None:\n",
        "            return 0.0\n",
        "\n",
        "        for frame in frames:\n",
        "            frame_bgr = frame.to_ndarray(format=\"bgr24\")\n",
        "            frame_gray = cv.cvtColor(frame_bgr, cv.COLOR_BGR2GRAY)\n",
        "            p1, st, _ = cv.calcOpticalFlowPyrLK(\n",
        "                old_gray, frame_gray, p0, None, **lk_params\n",
        "            )\n",
        "\n",
        "            tracked = None\n",
        "            if p1 is not None and st is not None:\n",
        "                # Keep only the points whose status flag marks them as found.\n",
        "                good_new = p1[st == 1]\n",
        "                good_old = p0[st == 1]\n",
        "\n",
        "                if len(good_new) > 0:\n",
        "                    distances = np.linalg.norm(good_new - good_old, axis=1)\n",
        "                    motion_per_frame.append(np.mean(distances))\n",
        "                    tracked = good_new.reshape(-1, 1, 2)\n",
        "\n",
        "            if tracked is not None:\n",
        "                p0 = tracked\n",
        "            else:\n",
        "                # Tracking failed or every point was lost: detect fresh\n",
        "                # features on the current frame rather than handing an empty\n",
        "                # point set to the next optical-flow call.\n",
        "                p0 = cv.goodFeaturesToTrack(\n",
        "                    frame_gray, mask=None, **feature_params\n",
        "                )\n",
        "                if p0 is None:\n",
        "                    break\n",
        "\n",
        "            # frame_gray is a new array on every iteration, so no copy is needed.\n",
        "            old_gray = frame_gray\n",
        "    finally:\n",
        "        # Release the container on every exit path, including early returns.\n",
        "        if container is not None:\n",
        "            container.close()\n",
        "\n",
        "    # float() turns the NumPy scalar into the plain float the signature declares.\n",
        "    return float(np.mean(motion_per_frame)) if motion_per_frame else 0.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "31597ed5a055"
      },
      "outputs": [],
      "source": [
        "# Read each video fully into memory, then report its estimated motion.\n",
        "for video_blob in blobs:\n",
        "    with video_blob.open(\"rb\") as blob_reader:\n",
        "        video_bytes = blob_reader.read()\n",
        "    video_file_object = io.BytesIO(video_bytes)\n",
        "    print(f\"{video_blob.name}: {estimate_video_motion(video_file_object)}\")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "quality-filtering.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
